# Map from two-letter postal abbreviations to full names for the 50 states,
# DC, and the US territories (56 entries total).
POSTAL_TO_STATE = as.list(c(
  AL = 'Alabama',        AK = 'Alaska',         AS = 'American Samoa',
  AZ = 'Arizona',        AR = 'Arkansas',       CA = 'California',
  CO = 'Colorado',       CT = 'Connecticut',    DE = 'Delaware',
  DC = 'District of Columbia', FL = 'Florida',  GA = 'Georgia',
  GU = 'Guam',           HI = 'Hawaii',         ID = 'Idaho',
  IL = 'Illinois',       IN = 'Indiana',        IA = 'Iowa',
  KS = 'Kansas',         KY = 'Kentucky',       LA = 'Louisiana',
  ME = 'Maine',          MD = 'Maryland',       MA = 'Massachusetts',
  MI = 'Michigan',       MN = 'Minnesota',      MS = 'Mississippi',
  MO = 'Missouri',       MT = 'Montana',        NE = 'Nebraska',
  NV = 'Nevada',         NH = 'New Hampshire',  NJ = 'New Jersey',
  NM = 'New Mexico',     NY = 'New York',       NC = 'North Carolina',
  ND = 'North Dakota',   MP = 'Northern Mariana Islands',
  OH = 'Ohio',           OK = 'Oklahoma',       OR = 'Oregon',
  PA = 'Pennsylvania',   PR = 'Puerto Rico',    RI = 'Rhode Island',
  SC = 'South Carolina', SD = 'South Dakota',   TN = 'Tennessee',
  TX = 'Texas',          UT = 'Utah',           VT = 'Vermont',
  VI = 'Virgin Islands', VA = 'Virginia',       WA = 'Washington',
  WV = 'West Virginia',  WI = 'Wisconsin',      WY = 'Wyoming'))

# Lowercase postal codes for the 50 states (DC and territories excluded).
# Iteration order here determines the row order of the stacked data frames.
states = unlist(strsplit(paste(
  "al ak az ar ca co ct de fl ga hi id il in ia ks ky la me md ma mi",
  "mn ms mo mt ne nv nh nj nm ny nc nd oh ok or pa ri sc sd tn tx ut",
  "vt va wa wv wi wy"), " "))

# Template URL for one state's daily symptoms CSV on GitHub. The '{state}'
# and '{state_underscore}' placeholders are substituted by load_state_data().
BASE_DAILY_URL = paste(
      'https://raw.githubusercontent.com/google-research/open-covid-19-data',
      'master/data/exports/search_trends_symptoms_dataset',
      'United%20States%20of%20America/subregions/{state}',
      '2020_US_{state_underscore}_daily_symptoms_dataset.csv',
      sep = '/')
# In-memory cache of per-state data frames, keyed by lowercase postal code
cache_data_list = list()
# Plain-language descriptions for a handful of notable symptom signals.
# Joined onto the correlation tables so printouts/plots can gloss the
# medical terms; signals without a row here get an NA description.
signal_description_df = tribble(
  ~signal,                      ~description,
  'Podalgia',                   'pain in the foot',
  'Anosmia',                    'loss of smell',
  'Purpura',                    "red/purple skin spots; 'blood spots'",
  'Radiculopathy',              'pinched nerve',
  'Ageusia',                    'loss of taste',
  'Erythema chronicum migrans', 'expanding rash early in lyme disease',
  'Photodermatitis',            'allergic rash that reqs light'
)
# Expand a lowercase postal code (e.g. "ny") to its full state name via the
# POSTAL_TO_STATE lookup table.
expand_state_name = function(state) {
  POSTAL_TO_STATE[[str_to_upper(state)]]
}

# Load the Google symptom-search data for one state (lowercase postal code).
# Checks an in-memory cache first, then an on-disk CSV cache, and only hits
# the network when neither has the data. Returns the parsed data frame.
load_state_data = function(state) {
  # In-memory cache hit: return the previously parsed data frame
  if (state %in% names(cache_data_list)) return (cache_data_list[[state]])
  # Check whether there is a cached version
  state_fname = sprintf('cache/%s.csv', state)
  # if there isn't, then download
  if (!file.exists(state_fname)) {
    state_name = expand_state_name(state)
    message(sprintf('Downloading data for %s...', state_name))
    # The URL template needs both the display name and an underscore-separated
    # variant (e.g. "New York" and "New_York")
    state_name_underscore = str_replace_all(state_name, ' ', '_')
    STATE_DAILY_URL = str_replace_all(BASE_DAILY_URL,
                                   fixed('{state}'), state_name)
    STATE_DAILY_URL = str_replace_all(STATE_DAILY_URL,
                                   fixed('{state_underscore}'),
                                   state_name_underscore)
    # Percent-encode any spaces remaining in the substituted URL
    STATE_DAILY_URL = str_replace_all(STATE_DAILY_URL,
                                   fixed(' '),
                                   '%20')
    download.file(STATE_DAILY_URL, state_fname)
  }
  single_state = readr::read_csv(state_fname)
  # Populate the global in-memory cache (`<<-` writes to the enclosing env)
  cache_data_list[[state]] <<- single_state
  return (single_state)
}


# Pull one symptom's signal for one state, shaped like a covidcast data
# frame: one row per (county FIPS, date) with columns geo_value, signal,
# time_value, direction, issue, lag, value, stderr, sample_size, data_source.
pull_data_state = function(state, symptom) {
  single_state = load_state_data(state)
  # Keep county-level rows only (state-level rows have NA sub_region_2_code).
  # The original also computed unique(single_state$sub_region_2_code) here and
  # discarded the result; that dead statement has been removed.
  single_state_counties = single_state[!is.na(single_state$sub_region_2_code), ]
  selected_symptom = paste0('symptom:', symptom)
  single_state_symptom = single_state_counties[, c('sub_region_2_code',
                                                   'date',
                                                   selected_symptom)]
  # Shape into the covidcast-style column layout
  colnames(single_state_symptom) = c('geo_value', 'time_value', 'value')
  single_state_symptom = single_state_symptom %>% filter(!is.na(value))
  single_state_symptom = single_state_symptom %>% transmute(
      # Zero-pad county FIPS codes to five digits
      geo_value = sprintf('%05d', as.numeric(geo_value)),
      signal = symptom,
      time_value = time_value,
      direction = NA,
      issue = lubridate::today(),
      lag = issue - time_value,
      value = value,
      stderr = NA,
      sample_size = NA,
      data_source = 'google_symptoms'
    )
  return(single_state_symptom)
}

Summary

Initial ingestion and exploration

# Build (or load from an RDS cache) the long data frame of all symptoms across
# all states: one row per (county, date, symptom).
if (file.exists('symptom_df.RDS')) {
  symptom_df = readRDS('symptom_df.RDS')
  symptom_names = unique(symptom_df$signal)
} else {
  # showWarnings = FALSE: re-running after a partial download is harmless
  dir.create('./cache/', showWarnings = FALSE)
  # Use one state's file to discover the symptom column names
  ak = load_state_data('ak')
  symptom_cols = colnames(ak)[
                    str_detect(colnames(ak), 'symptom:')]
  symptom_names = str_replace(symptom_cols, fixed('symptom:'), '')

  symptom_df_list = vector('list', length(symptom_names))
  names(symptom_df_list) = symptom_names

  for (symptom in symptom_names) {
    cat(symptom, '...\n')
    # One data frame per state, then stack (lapply replaces the original
    # 1:length(states) index loop)
    states_list = lapply(states, function(state) pull_data_state(state, symptom))
    symptom_df_list[[symptom]] = bind_rows(states_list)
  }
  symptom_df = bind_rows(symptom_df_list)
  saveRDS(symptom_df, 'symptom_df.RDS')
}
# Study period for the exploratory correlation analysis
start_day = "2020-03-01"
end_day = "2020-09-15"

# 7-day-average incident case counts per county from JHU CSSE
df_inum = covidcast_signal(data_source = "jhu-csse",
                   signal = "confirmed_7dav_incidence_num",
                   start_day = start_day,
                   end_day = end_day)

# Restrict to "active" counties: at least 500 total cases over the period
case_num = 500
geo_values = df_inum %>%
  group_by(geo_value) %>%
  summarize(total = sum(value)) %>%
  filter(total >= case_num) %>%
  pull(geo_value)
df_inum_act = df_inum %>% filter(geo_value %in% geo_values)
symptom_df_act = symptom_df %>% filter(geo_value %in% geo_values)

Here we plot the availability of each symptom over time (proportion is the percentage of counties for which the symptom was available). We see that for each signal, the availability level is consistent over time, subject to a strong weekend effect.

# Per-day, per-signal availability: share of active counties reporting each
# symptom on each day.
availability_df = symptom_df_act %>% group_by (
  time_value,
  signal,
) %>% summarize (
  prop_available = n() / length(geo_values),
) %>% ungroup (
)
## `summarise()` regrouping output by 'time_value' (override with `.groups` argument)
plt = (ggplot(availability_df)
       + geom_line(aes(x=time_value,
                       y=prop_available,
                       group=factor(signal)),
                   color='dodgerblue4',
                   size=0.1)
       # Title fixed: "Within-single" was a typo for "Within-signal"
       + ggtitle(paste('Within-signal availability stable over time,',
                       'weekend effects'))
       )
plt

The symptoms for which data is most sparse are:

# Average availability per signal; keep the signals present in at most 5% of
# active counties on a typical day, sparsest first.
most_missing = availability_df %>% group_by (
  signal,
) %>% summarize (
  avg_available = mean(prop_available)
) %>% ungroup (
) %>% filter (
  avg_available <= 0.05
) %>% arrange (
  avg_available,
)
## `summarise()` ungrouping output (override with `.groups` argument)
print(most_missing)
## # A tibble: 24 x 2
##    signal                  avg_available
##    <chr>                           <dbl>
##  1 Viral pneumonia                0.0219
##  2 Aphonia                        0.0266
##  3 Crackles                       0.0281
##  4 Burning Chest Pain             0.0291
##  5 Urinary urgency                0.0296
##  6 Polydipsia                     0.0297
##  7 Photodermatitis                0.0305
##  8 Shallow breathing              0.0331
##  9 Dysautonomia                   0.0332
## 10 Allergic conjunctivitis        0.0346
## # … with 14 more rows

For the signal that is most sparsely available, the number of counties at which it tends to be available daily is:

# Expected number of counties reporting the sparsest symptom on a typical day
print(min(most_missing$avg_available) * length(geo_values))
## [1] 28.94071

Based on this, we leave all the symptoms in for the full correlations analysis.

Correlations

# Per-date Spearman correlations between each symptom signal and case
# incidence, cached in cor_df.RDS.
if (file.exists('cor_df.RDS')) {
  cor_df = readRDS('cor_df.RDS')
} else {
  # cor_list is only needed when (re)computing, so build it here
  cor_list = vector('list', length(symptom_names))
  names(cor_list) = symptom_names

  for (symptom in symptom_names) {
    cat(symptom, '...\n')
    df_cor1 = covidcast_cor(symptom_df_act %>% filter(signal == symptom),
                            df_inum_act,
                            by = "time_value",
                            method = "spearman")
    df_cor1['signal'] = symptom
    cor_list[[symptom]] = df_cor1
  }
  cor_df = bind_rows(cor_list)
  saveRDS(cor_df, 'cor_df.RDS')
}
# Attach plain-language descriptions. Note: left_join takes `by =`, not
# `on =` (pandas-style); the original `on = 'signal'` was silently ignored
# and dplyr fell back to a natural join on the shared 'signal' column.
cor_df = cor_df %>% left_join(signal_description_df, by = 'signal')

Correlation over time: all symptoms

## Warning: Removed 2532 row(s) containing missing values (geom_path).

Correlation over time: largest single-day correlation symptoms

When we discuss the “size” of a correlation, we consider the absolute value of correlation.

# NOTE(review): `plot_cor_df` is created in a chunk not shown here
# (presumably derived from cor_df) -- confirm against the full document.
# "Size" of a correlation = absolute value; keep the 5 signals whose largest
# single-day |correlation| is greatest.
top_cor_signals = plot_cor_df %>% group_by (
    signal,
  ) %>% filter (
    abs(value) == max(abs(value), na.rm=TRUE),
  ) %>% ungroup (
  ) %>% arrange(
    -abs(value),
  ) %>% head (
    5,
  )
# Summary statistics of the correlation time series for those top signals
top_cor_sum_stats = plot_cor_df %>% filter (
    signal %in% top_cor_signals$signal,
  ) %>% group_by (
    signal,
  ) %>% summarize (
    min = min(value, na.rm=TRUE),
    quart1 = quantile(value, 0.25, na.rm=TRUE),
    med = median(value, na.rm=TRUE),
    mean = mean(value, na.rm=TRUE),
    quart3 = quantile(value, 0.75, na.rm=TRUE),
    max = max(value, na.rm=TRUE),
  ) %>% ungroup (
  )
## `summarise()` ungrouping output (override with `.groups` argument)
print('Symptoms with the largest all-time correlation:')
## [1] "Symptoms with the largest all-time correlation:"
# left_join takes `by =`, not `on =`; the original `on = 'signal'` was
# ignored and the join only worked via dplyr's natural-join fallback.
print(top_cor_signals %>% left_join(top_cor_sum_stats, by = 'signal')
        %>% select(-time_value, -value),
      width=100)
## # A tibble: 5 x 8
##   signal                     description                              min quart1
##   <chr>                      <chr>                                  <dbl>  <dbl>
## 1 Viral pneumonia            <NA>                                 -1      -0.396
## 2 Anosmia                    loss of smell                        -0.0788  0.311
## 3 Ageusia                    loss of taste                        -0.423   0.185
## 4 Erythema chronicum migrans expanding rash early in lyme disease -0.710  -0.562
## 5 Photodermatitis            allergic rash that reqs light        -0.709  -0.461
##        med    mean quart3   max
##      <dbl>   <dbl>  <dbl> <dbl>
## 1 -0.00360 -0.0600  0.295 0.668
## 2  0.556    0.499   0.721 0.833
## 3  0.478    0.412   0.672 0.797
## 4 -0.307   -0.344  -0.205 0.221
## 5 -0.343   -0.322  -0.172 0.282
# Correlation time series: the top-5 signals highlighted in color, all other
# signals drawn faintly in the background for context.
plt = (ggplot(plot_cor_df)
       + geom_line(aes(x=time_value,
                       y=value,
                       group=factor(signal)),
                   data=plot_cor_df %>% filter (
                      !signal %in% top_cor_signals$signal
                   ),
                   color='cornsilk',
                   size=0.1,
                   alpha=1.0)
       + geom_line(aes(x=time_value,
                       y=value,
                       group=factor(signal),
                       colour=factor(signal)
                       ),
                   data=plot_cor_df %>% filter (
                      signal %in% top_cor_signals$signal,
                   ),
                   #color='darkorange',
                   size=0.3)
       + ylab('rank correlation')
       # Semimonthly axis breaks across the study period
       + scale_x_date(breaks=lubridate::ymd(c('2020-03-01',
            '2020-03-15', '2020-04-01', '2020-04-15', '2020-05-01',
            '2020-05-15', '2020-06-01', '2020-06-15', '2020-07-01',
            '2020-07-15', '2020-08-01', '2020-08-15',
            '2020-09-01', '2020-09-15')))
       + theme(axis.text.x = element_text(angle = 45))
       + ggtitle("Top 5 signals by all-time max(|corr|)")
       )
plt
## Warning: Removed 2502 row(s) containing missing values (geom_path).
## Warning: Removed 30 row(s) containing missing values (geom_path).

Correlation over time: “consistently away from zero” symptoms

## `summarise()` ungrouping output (override with `.groups` argument)
## [1] "Symptoms that consistently stay away from 0 correlation:"
## Joining, by = "signal"
## # A tibble: 5 x 8
##   signal                 description          min quart1    med   mean quart3
##   <chr>                  <chr>              <dbl>  <dbl>  <dbl>  <dbl>  <dbl>
## 1 Podalgia               pain in the foot -0.431  -0.270 -0.225 -0.232 -0.182
## 2 Restless legs syndrome <NA>             -0.533  -0.406 -0.311 -0.308 -0.225
## 3 Hair loss              <NA>              0.0440  0.261  0.320  0.305  0.372
## 4 Radiculopathy          pinched nerve    -0.477  -0.336 -0.255 -0.251 -0.158
## 5 Hyperpigmentation      <NA>              0.0267  0.220  0.294  0.295  0.386
##       max
##     <dbl>
## 1 -0.0781
## 2 -0.0672
## 3  0.432 
## 4 -0.0425
## 5  0.485
## Warning: Removed 2502 row(s) containing missing values (geom_path).
## Warning: Removed 30 row(s) containing missing values (geom_path).

Correlation across location: largest single-day correlation symptoms

# Per-county Spearman correlations between each symptom signal and case
# incidence, cached in geo_cor_df.RDS.
if (file.exists('geo_cor_df.RDS')) {
  geo_cor_df = readRDS('geo_cor_df.RDS')
} else {
  geo_cor_list = vector('list', length(symptom_names))
  names(geo_cor_list) = symptom_names

  for (symptom in symptom_names) {
    cat(symptom, '...\n')
    df_cor1 = covidcast_cor(symptom_df_act %>% filter(signal == symptom),
                            df_inum_act,
                            by = "geo_value",
                            method = "spearman")
    df_cor1['signal'] = symptom
    geo_cor_list[[symptom]] = df_cor1
  }
  geo_cor_df = bind_rows(geo_cor_list)
  saveRDS(geo_cor_df, 'geo_cor_df.RDS')
}
# Attach descriptions. left_join takes `by =`, not `on =` (pandas-style);
# the original `on = 'signal'` was ignored and dplyr fell back to a natural
# join on the shared 'signal' column.
geo_cor_df = geo_cor_df %>% left_join(signal_description_df, by = 'signal')

The sign of the correlation with cases is fairly homogeneous in geography for viral pneumonia, anosmia, and ageusia, which increases my confidence that they will serve well as signals for predicting cases in a global model. I do not fully understand why there is a negative correlation between viral pneumonia and cases. Also important to note is that this handful of high-signal symptoms only cover a smattering of counties, mostly high-population areas – 50-100 counties out of 3000 total. For context, the modeling team’s county-level forecasts target roughly 200 counties, as of September 2020.

# Choropleth of per-county correlation for each top-correlation signal. The
# covidcast plot method expects a covidcast_signal object, so we fake the
# class and geo_type attributes on a plain data frame.
for (symptom in top_cor_signals$signal) {
  df_cor2 = geo_cor_df %>% filter (signal == symptom)
  # NOTE(review): `min_available_time` is defined in a chunk not shown here --
  # presumably the earliest available date; confirm against the full document.
  df_cor2$time_value = min_available_time
  df_cor2$issue = min_available_time
  attributes(df_cor2)$geo_type = 'county'
  class(df_cor2) = c("covidcast_signal", "data.frame")
  # Count counties with a non-missing correlation, for the plot title
  n_available_county = df_cor2 %>% filter (!is.na(value)) %>% nrow()

  # Plot choropleth maps, using the covidcast plotting functionality
  title_text = sprintf("Correlations between cases and %s (%d counties)",
                             symptom, n_available_county)
  # Append the plain-language description when one exists for this signal
  if (!is.na(df_cor2$description[1])) {
    title_text = paste0(title_text, '\n', sprintf('(%s)', df_cor2$description[1]))
  } 
  print(plot(df_cor2,
             title = title_text,
            range = c(-1, 1), choro_col = c("orange","lightblue", "purple")))
}

Correlation across location: “consistently away from zero” symptoms

The sign of the correlation with cases for “consistently away from zero” symptoms is also fairly homogeneous in location. However, the main takeaway, in my opinion, of these plots is to show the greater geographical coverage of this set of symptoms compared to the high-signal set of symptoms. I am led to believe that “consistently away from zero” is tied to noise-level, which is affected by sample size, which is greater for these (presumably common) symptoms.

# Choropleth of per-county correlation for each "consistently away from zero"
# signal. As above, we fake covidcast_signal class/attributes on a plain
# data frame so the covidcast plot method accepts it.
for (symptom in top_min_cor$signal) {
  df_cor2 = geo_cor_df %>% filter (signal == symptom)
  # NOTE(review): `top_min_cor` and `min_available_time` come from chunks not
  # shown here -- confirm their definitions against the full document.
  df_cor2$time_value = min_available_time
  df_cor2$issue = min_available_time
  attributes(df_cor2)$geo_type = 'county'
  class(df_cor2) = c("covidcast_signal", "data.frame")
  # Count counties with a non-missing correlation, for the plot title
  n_available_county = df_cor2 %>% filter (!is.na(value)) %>% nrow()

  # Plot choropleth maps, using the covidcast plotting functionality
  title_text = sprintf("Correlations between cases and %s (%d counties)",
                             symptom, n_available_county)
  # Append the plain-language description when one exists for this signal
  if (!is.na(df_cor2$description[1])) {
    title_text = paste0(title_text, '\n', sprintf('(%s)', df_cor2$description[1]))
  } 
  print(plot(df_cor2,
             title = title_text,
            range = c(-1, 1), choro_col = c("orange","lightblue", "purple")))
}

Rudimentary prediction problem

Here we use code liberally borrowed from an upcoming blog post to perform a prediction task.

We tried this approach, but the results do not look very promising :/

# Function to append shifted copies of `value` (lags or leads) to a data frame
# with columns geo_value, time_value, value. For each shift s in `shifts`, a
# column "value%+d" (e.g. "value-7", "value+14") is added holding `value`
# lagged (s < 0) or led (s > 0) by |s| rows per geo_value.
append_shifts = function(df, shifts) {
  # Make sure that we have a complete record of dates for each geo_value (fill
  # with NAs as necessary). summarize() here intentionally returns multiple
  # rows per group -- one per date in the observed range.
  df_all = df %>% group_by(geo_value) %>%
    summarize(time_value = seq.Date(as.Date(min(time_value)),
                                    as.Date(max(time_value)),
                                    by = "day")) %>% ungroup()
  df = full_join(df, df_all, by = c("geo_value", "time_value"))
  
  # Group by geo value, sort rows by increasing time so row shifts equal
  # day shifts
  df = df %>% group_by(geo_value) %>% arrange(time_value) 
  
  # Loop over shifts, and add lag value or lead value. Use plain if/else to
  # pick the shifting function: ifelse() is meant for vectors, not for
  # selecting between two closures with a scalar condition.
  for (shift in shifts) {
    fun = if (shift < 0) lag else lead
    varname = sprintf("value%+d", shift)
    df = mutate(df, !!varname := fun(value, n = abs(shift)))
  }
  
  # Ungroup and return
  return(ungroup(df))
}

# Transformation helpers. Each pair (Log, Exp) and (Logit, Sigmd) is mutually
# inverse; the additive offset `a` keeps zeros/boundaries finite. Id is the
# identity passthrough.
Log = function(x, a = 0.01) {
  log(x + a)
}
Exp = function(y, a = 0.01) {
  exp(y) - a
}
Logit = function(x, a = 0.01) {
  log((x + a) / (1 - x + a))
}
Sigmd = function(y, a = 0.01) {
  (exp(y) * (1 + a) - a) / (1 + exp(y))
}
Id = function(x) {
  x
}
 
# Transforms used in what follows (identity for this analysis)
trans = Id
inv_trans = Id

# Rescale factors bringing every signal down to a proportion in [0, 1]:
# symptom values arrive on a 0-100 scale, case rates as counts per 100,000.
rescale_symptom = 1e-2
rescale_case = 1e-5

# Consider only counties with at least 200 cumulative cases by Google's end
case_num = 200
geo_values = covidcast_signal("jhu-csse", "confirmed_cumulative_num",
                              "2020-05-14", "2020-05-14") %>%
  filter(value >= case_num) %>%
  pull(geo_value)

# Fetch county-level Google symptom signals and the JHU confirmed case
# incidence proportion, restrict to a common set of counties, append shifted
# columns, and join everything into one wide matrix `z`.
if (!exists('symptom_df')) {
  symptom_df = readRDS('symptom_df.RDS')
}
start_day = "2020-04-11"
end_day = "2020-09-01"

# Extract one symptom's (geo_value, time_value, value) triples, restricted
# to the qualifying counties and the study period
extract_symptom = function(df, symptom_name) {
  df %>% filter(signal == symptom_name) %>%
    select(geo_value, time_value, value) %>%
    filter(geo_value %in% geo_values,
           time_value >= start_day,
           time_value <= end_day)
}
anosmia = extract_symptom(symptom_df, 'Anosmia')
ageusia = extract_symptom(symptom_df, 'Ageusia')
viral_pneumonia = extract_symptom(symptom_df, 'Viral pneumonia')
case = covidcast_signal("jhu-csse", "confirmed_7dav_incidence_prop",
                        start_day, end_day) %>%
  select(geo_value, time_value, value) %>%
  filter(geo_value %in% geo_values)

# Counties for which all four series are present
geo_values_complete = Reduce(intersect, list(anosmia$geo_value,
                                             ageusia$geo_value,
                                             viral_pneumonia$geo_value,
                                             case$geo_value))

# Filter to complete counties, transform the signals, append 1-2 week lags to 
# all three, and also 1-2 week leads to case rates
lags = -1:-2 * 7
leads = 1:2 * 7

# Rescale, transform, and append shifted columns for one signal
prep_signal = function(df, rescale, shifts) {
  df %>% filter(geo_value %in% geo_values_complete) %>%
    mutate(value = trans(value * rescale)) %>%
    append_shifts(shifts = shifts)
}
anosmia = prep_signal(anosmia, rescale_symptom, lags)
ageusia = prep_signal(ageusia, rescale_symptom, lags)
viral_pneumonia = prep_signal(viral_pneumonia, rescale_symptom, lags)
case = prep_signal(case, rescale_case, c(lags, leads))

# Rename value columns so each signal's columns carry its name
colnames(anosmia) = sub("^value", "anosmia", colnames(anosmia))
colnames(ageusia) = sub("^value", "ageusia", colnames(ageusia))
colnames(viral_pneumonia) = sub("^value", "viral_pneumonia", colnames(viral_pneumonia))
colnames(case) = sub("^value", "case", colnames(case))

# Make one big matrix by joining these four data frames
z = full_join(full_join(full_join(anosmia, ageusia, by = c("geo_value", "time_value")),
              viral_pneumonia, by = c("geo_value", "time_value")),
              case, by = c("geo_value", "time_value"))

##### Analysis #####

# Use quantgen for LAD regression (this package supports quantile regression and
# more; you can find it on GitHub here: https://github.com/ryantibs/quantgen)
library(quantgen)

# One result data frame per forecast lead (filled in by the loop below)
res_list = vector("list", length = length(leads)) 
n = 14 # Number of trailing days to use for training set
verbose = TRUE # Print intermediate progress to console?

# Loop over lead, forecast dates, build models and record errors (warning: this
# computation takes a while). For each lead and date we fit two LAD models --
# cases-only and cases+symptoms -- and compare both to a strawman that carries
# the current case rate forward.
for (i in seq_along(leads)) { 
  lead = leads[i]; if (verbose) cat("***", lead, "***\n")
  
  # Create a data frame to store our results. Code below populates its rows in a
  # way that breaks from typical dplyr operations, done for efficiency
  res_list[[i]] = z %>% 
    filter(between(time_value, as.Date(start_day) - min(lags) + lead, 
                   as.Date(end_day) - lead)) %>%
    select(geo_value, time_value) %>%
    mutate(err0 = as.double(NA), err1 = as.double(NA), err2 = as.double(NA), 
           #err3 = as.double(NA), err4 = as.double(NA),
           lead = lead) 
  valid_dates = unique(res_list[[i]]$time_value)
  
  for (j in seq_along(valid_dates)) {
    date = valid_dates[j]; if (verbose) cat(format(date), "... ")
    
    # Filter down to training set (trailing n-day window, shifted back by the
    # lead so the response is observable) and test set
    z_tr = z %>% filter(between(time_value, date - lead - n, date - lead))
    z_te = z %>% filter(time_value == date)
    inds = which(res_list[[i]]$time_value == date)
    
    # Create training and test responses: case rate `lead` days ahead
    y_tr = z_tr %>% pull(paste0("case+", lead))
    y_te = z_te %>% pull(paste0("case+", lead))
    
    # Strawman model: predict today's case rate, unchanged
    if (verbose) cat("0")
    y_hat = z_te %>% pull(case)
    res_list[[i]][inds,]$err0 = abs(inv_trans(y_hat) - inv_trans(y_te))
    
    # Cases only model: current + lagged case-rate columns (exclude the
    # led "case+..." response columns)
    if (verbose) cat("1")
    x_tr_case = z_tr %>% select(starts_with("case") & !contains("+"))
    x_te_case = z_te %>% select(starts_with("case") & !contains("+"))
    x_tr = x_tr_case; x_te = x_te_case # For symmetry wrt what follows 
    ok = complete.cases(x_tr, y_tr)
    if (sum(ok) > 0) {
      # lambda = 0: plain (unpenalized) median regression
      obj = quantile_lasso(as.matrix(x_tr[ok,]), y_tr[ok], tau = 0.5,
                           lambda = 0, stand = FALSE, lp_solver = "gurobi")
      y_hat = as.numeric(predict(obj, newx = as.matrix(x_te)))
      res_list[[i]][inds,]$err1 = abs(inv_trans(y_hat) - inv_trans(y_te)) 
    }
    
    # Cases and symptoms model: add current + lagged symptom columns
    if (verbose) cat("2")
    x_tr_anosmia = z_tr %>% select(starts_with("anosmia"))
    x_te_anosmia = z_te %>% select(starts_with("anosmia"))
    x_tr_ageusia = z_tr %>% select(starts_with("ageusia"))
    x_te_ageusia = z_te %>% select(starts_with("ageusia"))
    x_tr_viral_pneumonia = z_tr %>% select(starts_with("viral_pneumonia"))
    x_te_viral_pneumonia = z_te %>% select(starts_with("viral_pneumonia"))
    x_tr = cbind(x_tr_case, x_tr_anosmia, x_tr_ageusia, x_tr_viral_pneumonia)
    x_te = cbind(x_te_case, x_te_anosmia, x_te_ageusia, x_te_viral_pneumonia)
    ok = complete.cases(x_tr, y_tr)
    if (sum(ok) > 0) {
      obj = quantile_lasso(as.matrix(x_tr[ok,]), y_tr[ok], tau = 0.5,
                           lambda = 0, stand = FALSE, lp_solver = "gurobi")
      y_hat = as.numeric(predict(obj, newx = as.matrix(x_te)))
      err_vec = abs(inv_trans(y_hat) - inv_trans(y_te))
      res_list[[i]][inds,]$err2 = err_vec
    }
  }
}

# Bind results over different leads into one big data frame
res = do.call(rbind, res_list)

# Calculate the median of the scaled errors for the various model: that is, the 
# errors relative to the strawman's error
res_med = res %>% mutate(err1 = err1 / err0, err2 = err2 / err0) %>%
  select(-err0) %>% 
  tidyr::pivot_longer(names_to = "model", values_to = "err", 
                      cols = -c(geo_value, time_value, lead)) %>%
  group_by(time_value, lead, model) %>% 
  summarize(err = median(err, na.rm = TRUE)) %>%
  ungroup() %>% 
  # NOTE(review): factor(..., labels=) relies on the alphabetical level order
  # ("err1" < "err2") to map labels correctly -- revisit if models are added
  mutate(lead = factor(lead, labels = paste(leads, "days ahead")),
         model = factor(model, labels = c("Cases", "Cases + Symptoms")))

# Cache raw and summarized errors
saveRDS(res, 'res.RDS')
saveRDS(res_med, 'res_med.RDS')

# Median scaled error by forecast date, faceted by lead time; points below
# the dashed line at 1 beat the strawman.
ggplot(res_med, aes(x = time_value, y = err)) + 
  geom_line(aes(color = model)) + 
  geom_hline(yintercept = 1, linetype = 2, color = "gray") +
  facet_wrap(vars(lead)) + 
  labs(x = "Date", y = "Scaled error", title = "Id transform") +
  theme_bw() +
  # `legend.position` is the actual theme() argument; the original
  # `legend.pos` only worked via partial argument matching.
  theme(legend.position = "bottom", legend.title = element_blank())